# Before running, install required packages:
{% if notebook %}

!
{%- else %}
#
{%- endif %}
 pip install numpy torch torchvision pytorch-ignite{% if visualization_tool == "Tensorboard" %} tensorboardX tensorboard{% endif %}{% if visualization_tool == "Weights & Biases" %} wandb{% endif %}{% if visualization_tool == "comet.ml" %} comet_ml{% endif %}{% if visualization_tool == "Aim" %} aim{% endif %}
{% if visualization_tool == "Weights & Biases" %}


# Also, you need to log in to Weights & Biases from the terminal:
{% if notebook %}

! wandb login
{% else %}
# wandb login
{% endif %}
{% endif %}
{% if notebook %}


# ---
{% endif %}


{% if visualization_tool == "comet.ml" %}
from comet_ml import Experiment  # has to be 1st import
{% endif %}
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
{% if data_format == "Image files" %}
import urllib
import zipfile
{% endif %}
{% if visualization_tool == "Tensorboard" or checkpoint %}
from datetime import datetime
{% endif %}
{% if visualization_tool == "Tensorboard" %}
from tensorboardX import SummaryWriter
{% elif visualization_tool == "Aim" %}
from aim import Session
{% elif visualization_tool == "Weights & Biases" %}
import wandb
{% endif %}
{% if checkpoint %}
from pathlib import Path
{% endif %}

{% if data_format == "Numpy arrays" %}
def fake_data():
    """Build a tiny random dataset for trying out the script.

    Returns a two-element list ``[images, labels]``: 4 random images of
    shape 1x16x16 and their integer labels 0, 1, 2, 3.
    """
    num_samples = 4
    images = np.random.rand(num_samples, 1, 16, 16)
    labels = np.arange(num_samples)
    return [images, labels]

{% elif data_format == "Image files" %}
# COMMENT THIS OUT IF YOU USE YOUR OWN DATA.
# Download example data into ./data/image-data (4 image files, 2 for "dog", 2 for "cat").
# Note: `import urllib` alone does not load the `urllib.request` submodule, so it
# is imported explicitly here instead of relying on another package loading it.
import urllib.request

url = "https://github.com/jrieke/traingenerator/raw/main/data/fake-image-data.zip"
zip_path, _ = urllib.request.urlretrieve(url)  # saved to a temporary file
with zipfile.ZipFile(zip_path, "r") as f:
    f.extractall("data")

{% endif %}

{{ header("Setup") }}
{% if data_format == "Numpy arrays" %}
# INSERT YOUR DATA HERE
# Expected format: [images, labels]
# - images has array shape (num samples, color channels, height, width)
# - labels has array shape (num samples, )
# Datasets set to None are skipped during preprocessing.
train_data = fake_data()  # required
val_data = fake_data()    # optional
test_data = None          # optional
{% elif data_format == "Image files" %}
# INSERT YOUR DATA HERE
# Expected format: One folder per class, e.g.
# train
# --- dogs
# |   +-- lassie.jpg
# |   +-- komissar-rex.png
# --- cats
# |   +-- garfield.png
# |   +-- smelly-cat.png
#
# Example: https://github.com/jrieke/traingenerator/tree/main/data/image-data
# Datasets set to None are skipped during preprocessing.
train_data = "data/image-data"  # required
val_data = "data/image-data"    # optional
test_data = None                # optional
{% elif data_format == "Public dataset"%}
# Dataset {{ dataset }} will be loaded further down.
{% endif %}

# Set up hyperparameters.
lr = {{ lr }}
batch_size = {{ batch_size }}
num_epochs = {{ num_epochs }}

# Set up logging.
{% if visualization_tool == "Tensorboard" or checkpoint %}
# Timestamp-based ID that names this run's log and/or checkpoint directory.
experiment_id = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
{% endif %}
{% if visualization_tool == "Tensorboard" %}
writer = SummaryWriter(logdir=f"logs/{experiment_id}")
{% elif visualization_tool == "Aim" %}
aim_session = Session({% if aim_experiment %}experiment="{{ aim_experiment }}"{% endif %})
aim_session.set_params({"lr": lr, "batch_size": batch_size, "num_epochs": num_epochs}, name="hparams")
{% elif visualization_tool == "Weights & Biases" %}
wandb.init(
{% if wb_project %}
    project="{{ wb_project }}", 
{% endif %}
{% if wb_name %}
    name="{{ wb_name }}", 
{% endif %}
    config={"lr": lr, "batch_size": batch_size, "num_epochs": num_epochs}
)
{% elif visualization_tool == "comet.ml" %}
experiment = Experiment("{{ comet_api_key }}"{% if comet_project %}, project_name="{{ comet_project }}"{% endif %})
{% endif %}
{% if checkpoint %}
# One sub-directory per run, created up front so the per-epoch saves have a target.
checkpoint_dir = Path(f"checkpoints/{experiment_id}")
checkpoint_dir.mkdir(parents=True, exist_ok=True)
{% endif %}
print_every = {{ print_every }}  # batches between two printed training-loss lines

# Set up device.
{% if gpu %}
# Use the GPU only if CUDA is actually available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
{% else %}
use_cuda = False
{% endif %}
device = torch.device("cuda" if use_cuda else "cpu")


{% if data_format == "Public dataset" %}
{{ header("Dataset & Preprocessing") }}
def load_data(train):
    """Download the public dataset and wrap it in a data loader.

    Args:
        train: Passed to the torchvision dataset to select the training split
            (True) or the held-out split (False). The loader is shuffled only
            when this is True.

    Returns:
        A `DataLoader` yielding batches of `batch_size` transformed samples.
    """
    # Download and transform dataset.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        {% if dataset == "MNIST" or dataset == "FashionMNIST"%}
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # grayscale to RGB
        {% endif %}
        {% if pretrained %}
        # Pretrained torchvision models expect ImageNet-normalized inputs
        # (same normalization as used for custom data).
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        {% endif %}
    ])
    dataset = datasets.{{ dataset }}("./data", train=train, download=True, transform=transform)

    # Wrap in data loader.
    if use_cuda:
        # Pinned memory speeds up host-to-GPU copies.
        kwargs = {"pin_memory": True, "num_workers": 1}
    else:
        kwargs = {}
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=train, **kwargs)
    return loader

train_loader = load_data(train=True)
val_loader = None  # public datasets are only loaded as train and test splits here
test_loader = load_data(train=False)
{% else %}
{{ header("Preprocessing") }}
def preprocess(data, name):
    """Convert one raw dataset into a data loader.

    Args:
        data: Dataset in the format described in the setup section above, or
            None if this split is not used.
        name: Name of the split; only the "train" split is shuffled.

    Returns:
        A `DataLoader` with batches of size `batch_size`, or None if `data`
        is None.
    """
    if data is None:  # val/test can be empty
        return None

    {% if data_format == "Image files" %}
    # Read image files to pytorch dataset.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        {# TODO: Maybe add normalization option even if model is not pretrained #}
        {% if pretrained %}
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        {% endif %}
    ])
    dataset = datasets.ImageFolder(data, transform=transform)
    {% elif data_format == "Numpy arrays" %}
    images, labels = data

    # Rescale images to 0-255 and convert to uint8.
    # Note: This is done for each dataset individually, which is usually ok if all
    # datasets look similar. If not, scale all datasets based on min/ptp of train set.
    # Constant images have a value range of 0; fall back to 1 to avoid dividing by zero
    # (such images are mapped to all zeros).
    value_range = np.ptp(images) or 1
    images = (images - np.min(images)) / value_range * 255
    images = images.astype(np.uint8)

    # If images are grayscale, convert to RGB by duplicating channels.
    if images.shape[1] == 1:
        images = np.stack((images[:, 0],) * 3, axis=1)

    # Resize images and transform images torch tensor.
    images = images.transpose((0, 2, 3, 1))  # channels-last, required for transforms.ToPILImage
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        {% if pretrained %}
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        {% endif %}
    ])
    {# TODO: This is quite ugly and very inefficient #}
    images = torch.stack(list(map(transform, images)))

    # Convert labels to tensors.
    labels = torch.from_numpy(labels).long()

    # Construct dataset.
    dataset = TensorDataset(images, labels)
    {% endif %}

    # Wrap in data loader.
    {% if gpu %}
    if use_cuda:
        # Pinned memory speeds up host-to-GPU copies.
        kwargs = {"pin_memory": True, "num_workers": 1}
    else:
        kwargs = {}
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=(name == "train"), **kwargs)
    {% else %}
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=(name == "train"))
    {% endif %}
    return loader

train_loader = preprocess(train_data, "train")
val_loader = preprocess(val_data, "val")
test_loader = preprocess(test_data, "test")
{% endif %}


{{ header("Model") }}
# Build the network, then the loss function and the optimizer.
model = models.{{ model_func }}(pretrained={{ pretrained }})
{# TODO: Maybe enable this by default, so that people can adapt num_classes afterward. #}
{% if num_classes != 1000 %}
# Replace the classification layer with one that has `num_classes` outputs.
num_classes = {{ num_classes }}
{% if "resnet" in model_func %}
model.fc = nn.Linear(model.fc.in_features, num_classes)
{% elif "alexnet" in model_func or "vgg" in model_func %}
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
{% elif "densenet" in model_func %}
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
{% endif %}
{% endif %}
model = model.to(device)
loss_func = nn.{{ loss }}()
optimizer = optim.{{ optimizer }}(model.parameters(), lr=lr)

{% if visualization_tool == "Weights & Biases" %}
# Let W&B track the model's gradients and parameters.
wandb.watch(model)

{% endif %}

{{ header("Training") }}
# Create the pytorch-ignite engines: one for training, one for evaluation.
trainer = create_supervised_trainer(model, optimizer, loss_func, device=device)
{# TODO: Atm, the train metrics get accumulated, see torch_models.py #}
# Metrics computed by the evaluator on each run.
metrics = {"accuracy": Accuracy(), "loss": Loss(loss_func)}
evaluator = create_supervised_evaluator(
    model,
    metrics=metrics,
    device=device,
)

@trainer.on(Events.ITERATION_COMPLETED(every=print_every))
def log_batch(trainer):
    """Print the current training loss every `print_every` iterations."""
    state = trainer.state
    # `iteration` counts across epochs; convert it to a 1-based index within the epoch.
    batch = (state.iteration - 1) % state.epoch_length + 1
    message = (
        f"Epoch {state.epoch} / {num_epochs}, "
        f"batch {batch} / {state.epoch_length}: "
        f"loss: {state.output:.3f}"
    )
    print(message)

@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch(trainer):
    """Evaluate the model on every available data split after each epoch.

    Runs the evaluator on the train loader and, if present, on the val and
    test loaders, and reports the average loss and accuracy of each.
    """
    print(f"Epoch {trainer.state.epoch} / {num_epochs} average results: ")

    def log_results(name, metrics, epoch):
        """Report the loss and accuracy in `metrics` for the split `name`."""
        print(
            f"{name + ':':6} loss: {metrics['loss']:.3f}, "
            f"accuracy: {metrics['accuracy']:.3f}"
        )
        {% if visualization_tool == "Tensorboard" %}
        writer.add_scalar(f"{name}_loss", metrics["loss"], epoch)
        writer.add_scalar(f"{name}_accuracy", metrics["accuracy"], epoch)
        {% elif visualization_tool == "Aim" %}
        aim_session.track(metrics["loss"], name="loss", subset=name, epoch=epoch)
        aim_session.track(metrics["accuracy"], name="accuracy", subset=name, epoch=epoch)
        {% elif visualization_tool == "Weights & Biases" %}
        wandb.log({f"{name}_loss": metrics["loss"], f"{name}_accuracy": metrics["accuracy"]})
        {% elif visualization_tool == "comet.ml" %}
        # Pass the epoch along so the metrics are recorded against it.
        experiment.log_metric(f"{name}_loss", metrics["loss"], epoch=epoch)
        experiment.log_metric(f"{name}_accuracy", metrics["accuracy"], epoch=epoch)
        {% endif %}

    # Val/test loaders are None when no data was provided for them. Compare
    # against None explicitly: truthiness of a DataLoader calls len() on it.
    splits = (("train", train_loader), ("val", val_loader), ("test", test_loader))
    for name, loader in splits:
        if loader is None:
            continue
        evaluator.run(loader)
        log_results(name, evaluator.state.metrics, trainer.state.epoch)

    print()
    print("-" * 80)
    print()

{# TODO: Maybe use this instead: https://pytorch.org/ignite/handlers.html#ignite.handlers.ModelCheckpoint #}
{% if checkpoint %}
@trainer.on(Events.EPOCH_COMPLETED)
def checkpoint_model(trainer):
    """Save the model to the checkpoint directory after each epoch."""
    checkpoint_path = checkpoint_dir / f"model-epoch{trainer.state.epoch}.pt"
    torch.save(model, checkpoint_path)

{% endif %}
# Start training.
trainer.run(train_loader, max_epochs=num_epochs)
{% if visualization_tool == "Tensorboard" %}
# Close the writer so that all pending events are flushed to disk.
writer.close()
{% elif visualization_tool == "Weights & Biases" %}
# Mark the W&B run as finished.
wandb.finish()
{% endif %}
